# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
import warnings

import torch
import torch.nn as nn
from classy_vision.generic.util import get_torch_version
from torch import Tensor


def _trunc_normal_fallback(
    tensor: Tensor,
    mean: float = 0.0,
    std: float = 1.0,
    a: float = -2.0,
    b: float = 2.0,
) -> Tensor:
    """Fill ``tensor`` in place with samples from a truncated normal distribution.

    Fallback used on torch < 1.7, where ``nn.init.trunc_normal_`` is not
    available. Code copied from
    https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
    (commit: e9b369c). Method based on
    https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

    ``a`` and ``b`` are absolute cutoff values for the samples, not multiples
    of ``std``.

    Args:
        tensor: tensor to fill in place.
        mean: mean of the underlying normal distribution.
        std: standard deviation of the underlying normal distribution.
        a: minimum cutoff value.
        b: maximum cutoff value.

    Returns:
        ``tensor`` itself, after being modified in place.
    """

    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        # stacklevel=2 attributes the warning to this helper's caller.
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2,
        )

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1] (the domain expected by erfinv).
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.0))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range (guards against rounding
        # at the edges of the erfinv transform).
        tensor.clamp_(min=a, max=b)
        return tensor


def lecun_normal_init(tensor, fan_in):
    """Fill ``tensor`` in place with LeCun-normal initial values.

    Draws from a normal distribution with mean 0 and standard deviation
    ``sqrt(1 / fan_in)``, truncated to the absolute range [-2, 2] (the default
    bounds of ``trunc_normal_``). Uses ``nn.init.trunc_normal_`` on torch >= 1.7
    and a copied implementation on older versions.

    NOTE(review): the truncation bounds are absolute (±2), not ±2 std as in
    JAX/Flax ``lecun_normal``. For ``fan_in > 1`` the cutoff therefore lies
    beyond 2 std, and the result is close to an untruncated normal. Confirm
    this is intended before relying on JAX parity.

    Args:
        tensor: tensor to initialize in place.
        fan_in: number of input units used to scale the standard deviation.
            Must be positive: ``0`` raises ``ZeroDivisionError`` and a negative
            value raises ``ValueError`` from ``math.sqrt``.

    Returns:
        ``tensor``, modified in place. It is returned for convenience, mirroring
        the ``torch.nn.init`` functions; existing callers that ignore the
        return value are unaffected.
    """
    if get_torch_version() >= [1, 7]:
        trunc_normal_ = nn.init.trunc_normal_
    else:
        trunc_normal_ = _trunc_normal_fallback

    trunc_normal_(tensor, std=math.sqrt(1 / fan_in))
    return tensor
